!--------------------------------------------------------------------------------------------------!
!   CP2K: A general program to perform molecular dynamics simulations                              !
!   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
!                                                                                                  !
!   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
!--------------------------------------------------------------------------------------------------!

! **************************************************************************************************
!> \brief  Methods that handle helium-solvent and helium-helium interactions
!> \author Lukasz Walewski
!> \date   2009-06-10
! **************************************************************************************************
MODULE helium_interactions

   USE cp_log_handling,                 ONLY: cp_get_default_logger,&
                                              cp_logger_type
   USE helium_common,                   ONLY: helium_eval_chain,&
                                              helium_eval_expansion,&
                                              helium_pbc,&
                                              helium_spline
   USE helium_nnp,                      ONLY: helium_nnp_print
   USE helium_types,                    ONLY: e_id_interact,&
                                              e_id_kinetic,&
                                              e_id_potential,&
                                              e_id_thermo,&
                                              e_id_total,&
                                              e_id_virial,&
                                              helium_solvent_p_type,&
                                              helium_solvent_type
   USE input_constants,                 ONLY: helium_sampling_worm,&
                                              helium_solute_intpot_mwater,&
                                              helium_solute_intpot_nnp,&
                                              helium_solute_intpot_none
   USE input_section_types,             ONLY: section_vals_get_subs_vals,&
                                              section_vals_type
   USE kinds,                           ONLY: dp
   USE nnp_acsf,                        ONLY: nnp_calc_acsf
   USE nnp_environment_types,           ONLY: nnp_type
   USE nnp_model,                       ONLY: nnp_gradients,&
                                              nnp_predict
   USE physcon,                         ONLY: angstrom,&
                                              kelvin
   USE pint_types,                      ONLY: pint_env_type
   USE splines_types,                   ONLY: spline_data_type
#include "../base/base_uses.f90"

   IMPLICIT NONE

   PRIVATE

   LOGICAL, PRIVATE, PARAMETER :: debug_this_module = .TRUE.
   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'helium_interactions'

   PUBLIC :: helium_calc_energy
   PUBLIC :: helium_total_link_action
   PUBLIC :: helium_total_pair_action
   PUBLIC :: helium_total_inter_action
   PUBLIC :: helium_solute_e_f
   PUBLIC :: helium_bead_solute_e_f
   PUBLIC :: helium_intpot_scan
   PUBLIC :: helium_vij

CONTAINS

! **************************************************************************************************
!> \brief  Calculate the helium energy (including helium-solute interaction)
!> \param    helium     helium environment; helium%energy_inst(e_id_*) is filled on return
!> \param    pint_env   path integral environment (only used for the helium-solute term)
!> \par History
!>         2009-06 moved I/O out from here [lwalewski]
!> \author hforbert
! **************************************************************************************************
   SUBROUTINE helium_calc_energy(helium, pint_env)
      TYPE(helium_solvent_type), INTENT(INOUT)           :: helium
      TYPE(pint_env_type), INTENT(IN)                    :: pint_env

      INTEGER                                            :: b, bead, i, j, n
      INTEGER, DIMENSION(:), POINTER                     :: perm
      LOGICAL                                            :: nperiodic
      REAL(KIND=dp)                                      :: a, cell_size, en, interac, kin, pot, &
                                                            rmax, rmin, vkin
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: work2, work3
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: work
      REAL(KIND=dp), DIMENSION(3)                        :: r
      REAL(KIND=dp), DIMENSION(:, :, :), POINTER         :: pos

      pos => helium%pos
      perm => helium%permutation
      ! half of the cell edge: for periodic systems only pair distances below this
      ! value contribute to the pair potential sum
      cell_size = 0.5_dp*helium%cell_size
      nperiodic = .NOT. helium%periodic
      n = helium%atoms
      b = helium%beads
      en = 0.0_dp
      pot = 0.0_dp
      ! range of pair distances seen (only referenced by the commented-out diagnostics)
      rmin = 1.0e20_dp
      rmax = 0.0_dp
      ALLOCATE (work(3, helium%beads + 1), &
                work2(helium%beads + 1), &
                work3(SIZE(helium%uoffdiag, 1) + 1))
      DO i = 1, n - 1
         DO j = i + 1, n
            ! pair separation at every bead ...
            DO bead = 1, b
               work(:, bead) = pos(:, i, bead) - pos(:, j, bead)
            END DO
            ! ... plus the closing link to bead 1 of the permuted partners
            work(:, b + 1) = pos(:, perm(i), 1) - pos(:, perm(j), 1)
            en = en + helium_eval_chain(helium, work, b + 1, work2, work3, energy=.TRUE.)
            ! work2(bead) is used below as the pair distance at each bead
            DO bead = 1, b
               a = work2(bead)
               IF (a < rmin) rmin = a
               IF (a > rmax) rmax = a
               IF ((a < cell_size) .OR. nperiodic) THEN
                  pot = pot + helium_spline(helium%vij, a)
               END IF
            END DO
         END DO
      END DO
      DEALLOCATE (work, work2, work3)
      ! average over imaginary-time slices
      pot = pot/b
      en = en/b

      ! helium-solute interaction energy (all beads of all particles)
      interac = 0.0_dp
      IF (helium%solute_present) THEN
         CALL helium_solute_e(pint_env, helium, interac)
      END IF
      interac = interac/b

!TODO:
      vkin = 0.0_dp
!   vkin = helium_virial_energy(helium)

      ! sum of squared (minimum-image) bead-to-bead displacements along all paths,
      ! the last bead of atom i linking to bead 1 of atom perm(i)
      kin = 0.0_dp
      DO i = 1, n
         r(:) = pos(:, i, b) - pos(:, perm(i), 1)
         CALL helium_pbc(helium, r)
         kin = kin + r(1)*r(1) + r(2)*r(2) + r(3)*r(3)
         DO bead = 2, b
            r(:) = pos(:, i, bead - 1) - pos(:, i, bead)
            CALL helium_pbc(helium, r)
            kin = kin + r(1)*r(1) + r(2)*r(2) + r(3)*r(3)
         END DO
      END DO
      ! kinetic energy: 3N/(2 tau) - sum|dr|^2 / (2 P tau^2 hb2m)
      kin = 1.5_dp*n/helium%tau - 0.5_dp*kin/(b*helium%tau**2*helium%hb2m)

! TODO: move printing somewhere else ?
!   print *,"POT = ",(pot/n+helium%e_corr)*kelvin,"K"
!   print *,"INTERAC = ",interac*kelvin,"K"
!   print *,"RMIN= ",rmin*angstrom,"A"
!   print *,"RMAX= ",rmax*angstrom,"A"
!   print *,"EVIRIAL not valid!"
!   print *,"ETHERMO= ",((en+kin)/n+helium%e_corr)*kelvin,"K"
!   print *,"ECORR= ",helium%e_corr*kelvin,"K"
!!   kin = helium_total_action(helium)
!!   print *,"ACTION= ",kin
!   print *,"WINDING#= ",helium_calc_winding(helium)

      helium%energy_inst(e_id_potential) = pot/n + helium%e_corr
      helium%energy_inst(e_id_kinetic) = (en - pot + kin)/n
      helium%energy_inst(e_id_interact) = interac
      helium%energy_inst(e_id_thermo) = (en + kin)/n + helium%e_corr
      helium%energy_inst(e_id_virial) = vkin ! 0.0_dp at the moment
      helium%energy_inst(e_id_total) = helium%energy_inst(e_id_thermo)
      ! Once vkin is properly implemented, switch to:
      ! helium%energy_inst(e_id_total) = (en+vkin)/n+helium%e_corr

   END SUBROUTINE helium_calc_energy

! **************************************************************************************************
!> \brief  Computes the total harmonic link action of the helium
!> \param helium  helium environment (positions, permutation, tau and hb2m are read)
!> \return sum over every link of |r(k) - r(k+1)|**2 / (2*tau*hb2m), using
!>         minimum-image separations from helium_pbc; the last bead of atom i is
!>         linked to bead 1 of atom permutation(i).
!>         Per link this is (r(m-1) - r(m))**2/(4*lambda*tau).
!> \date   2016-05-03
!> \author Felix Uhl
! **************************************************************************************************
   REAL(KIND=dp) FUNCTION helium_total_link_action(helium) RESULT(linkaction)

      TYPE(helium_solvent_type), INTENT(IN)              :: helium

      INTEGER                                            :: ia, ib, nbeads
      REAL(KIND=dp), DIMENSION(3)                        :: dr

      nbeads = helium%beads
      linkaction = 0.0_dp

      ! links inside each path: bead ib -> bead ib+1 of the same atom
      DO ib = 1, nbeads - 1
         DO ia = 1, helium%atoms
            dr(:) = helium%pos(:, ia, ib) - helium%pos(:, ia, ib + 1)
            CALL helium_pbc(helium, dr)
            linkaction = linkaction + (dr(1)*dr(1) + dr(2)*dr(2) + dr(3)*dr(3))
         END DO
      END DO

      ! closing links: last bead -> first bead of the partner from the permutation table
      DO ia = 1, helium%atoms
         dr(:) = helium%pos(:, ia, nbeads) - helium%pos(:, helium%permutation(ia), 1)
         CALL helium_pbc(helium, dr)
         linkaction = linkaction + (dr(1)*dr(1) + dr(2)*dr(2) + dr(3)*dr(3))
      END DO

      linkaction = linkaction/(2.0_dp*helium%tau*helium%hb2m)

   END FUNCTION helium_total_link_action

! **************************************************************************************************
!> \brief  Computes the total pair action of the helium
!> \param helium  helium environment
!> \return sum of helium_eval_expansion over all He pairs (i < j) and all links
!>         between consecutive beads; the closing link uses the permutation table.
!>         For an open worm, the links touching the worm bead are re-evaluated with
!>         helium%worm_xtra_bead in place of the regular bead position.
!> \date   2016-05-03
!> \author Felix Uhl
! **************************************************************************************************
   REAL(KIND=dp) FUNCTION helium_total_pair_action(helium) RESULT(pairaction)

      TYPE(helium_solvent_type), INTENT(INOUT)           :: helium

      INTEGER                                            :: ia, ia_prev, ib, ja, ja_prev, &
                                                            natoms, nbeads
      INTEGER, DIMENSION(:), POINTER                     :: perm
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: scratch
      REAL(KIND=dp), DIMENSION(3)                        :: sep_a, sep_b

      natoms = helium%atoms
      nbeads = helium%beads
      perm => helium%permutation
      ALLOCATE (scratch(SIZE(helium%uoffdiag, 1) + 1))
      pairaction = 0.0_dp

      ! pair links between bead ib and bead ib+1
      DO ib = 1, nbeads - 1
         DO ia = 1, natoms - 1
            DO ja = ia + 1, natoms
               sep_a(:) = helium%pos(:, ia, ib) - helium%pos(:, ja, ib)
               sep_b(:) = helium%pos(:, ia, ib + 1) - helium%pos(:, ja, ib + 1)
               pairaction = pairaction + helium_eval_expansion(helium, sep_a, sep_b, scratch)
            END DO
         END DO
      END DO

      ! closing pair links: last bead -> bead 1 of the permuted partners
      DO ia = 1, natoms - 1
         DO ja = ia + 1, natoms
            sep_a(:) = helium%pos(:, ia, nbeads) - helium%pos(:, ja, nbeads)
            sep_b(:) = helium%pos(:, perm(ia), 1) - helium%pos(:, perm(ja), 1)
            pairaction = pairaction + helium_eval_expansion(helium, sep_a, sep_b, scratch)
         END DO
      END DO

      ! open worm: swap the closed-link contribution for one that uses the extra bead
      IF (.NOT. helium%worm_is_closed) THEN
         ia = helium%worm_atom_idx
         ib = helium%worm_bead_idx
         IF (ib == 1) THEN
            ! worm bead on slice 1: the partner link reaches back to slice nbeads of the
            ! atoms in front (inverse permutation)
            ia_prev = helium%iperm(ia)
            DO ja = 1, natoms
               IF (ja == ia) CYCLE
               ja_prev = helium%iperm(ja)
               ! remove the closed link ...
               sep_a(:) = helium%pos(:, ia, 1) - helium%pos(:, ja, 1)
               sep_b(:) = helium%pos(:, ia_prev, nbeads) - helium%pos(:, ja_prev, nbeads)
               pairaction = pairaction - helium_eval_expansion(helium, sep_a, sep_b, scratch)
               ! ... and add the link ending at the extra bead (sep_b unchanged)
               sep_a(:) = helium%worm_xtra_bead(:) - helium%pos(:, ja, 1)
               pairaction = pairaction + helium_eval_expansion(helium, sep_a, sep_b, scratch)
            END DO
         ELSE
            DO ja = 1, natoms
               IF (ja == ia) CYCLE
               ! remove the closed link between slices ib-1 and ib ...
               sep_a(:) = helium%pos(:, ia, ib) - helium%pos(:, ja, ib)
               sep_b(:) = helium%pos(:, ia, ib - 1) - helium%pos(:, ja, ib - 1)
               pairaction = pairaction - helium_eval_expansion(helium, sep_a, sep_b, scratch)
               ! ... and add the link ending at the extra bead (sep_b unchanged)
               sep_a(:) = helium%worm_xtra_bead(:) - helium%pos(:, ja, ib)
               pairaction = pairaction + helium_eval_expansion(helium, sep_a, sep_b, scratch)
            END DO
         END IF
      END IF
      DEALLOCATE (scratch)

   END FUNCTION helium_total_pair_action

! **************************************************************************************************
!> \brief  Computes the total interaction of the helium with the solute
!> \param pint_env  path integral environment
!> \param helium    helium environment
!> \return tau times the helium-solute energy summed over all beads of all atoms.
!>         With worm sampling and an open worm, half of the worm-bead term is
!>         replaced by half of the term at helium%worm_xtra_bead. Zero if no
!>         solute is present.
!> \date   2016-05-03
!> \author Felix Uhl
! **************************************************************************************************
   REAL(KIND=dp) FUNCTION helium_total_inter_action(pint_env, helium) RESULT(interaction)

      TYPE(pint_env_type), INTENT(IN)                    :: pint_env
      TYPE(helium_solvent_type), INTENT(IN)              :: helium

      INTEGER                                            :: ia, ib, w_atom, w_bead
      REAL(KIND=dp)                                      :: e_bead

      interaction = 0.0_dp

      IF (helium%solute_present) THEN
         ! every bead of every helium atom
         DO ib = 1, helium%beads
            DO ia = 1, helium%atoms
               CALL helium_bead_solute_e_f(pint_env, helium, ia, ib, helium%pos(:, ia, ib), e_bead)
               interaction = interaction + e_bead
            END DO
         END DO

         IF (helium%sampling_method == helium_sampling_worm) THEN
            IF (.NOT. helium%worm_is_closed) THEN
               w_atom = helium%worm_atom_idx
               w_bead = helium%worm_bead_idx
               ! the tail bead counts only half ...
               CALL helium_bead_solute_e_f(pint_env, helium, w_atom, w_bead, &
                                           helium%pos(:, w_atom, w_bead), e_bead)
               interaction = interaction - 0.5_dp*e_bead
               ! ... the other half comes from the head (extra) bead
               CALL helium_bead_solute_e_f(pint_env, helium, w_atom, w_bead, &
                                           helium%worm_xtra_bead, e_bead)
               interaction = interaction + 0.5_dp*e_bead
            END IF
         END IF
      END IF

      interaction = interaction*helium%tau

   END FUNCTION helium_total_inter_action

! **************************************************************************************************
!> \brief Calculate general helium-solute interaction energy (and forces)
!>        between one helium bead and the corresponding solute time slice.
!> \param pint_env           path integral environment
!> \param helium             helium environment
!> \param helium_part_index  helium particle index
!> \param helium_slice_index helium time slice index
!> \param helium_r_opt       explicit helium bead coordinates (optional); when absent
!>                           helium%pos(:, helium_part_index, helium_slice_index) is used
!> \param energy             calculated energy; 0 when no interaction is evaluated
!> \param force              calculated force (if requested), indexed as
!>                           force(solute bead, solute coordinate); only the row of the
!>                           solute bead mapped to this helium slice is filled
!> \par History
!>         2019-09 Added multiple-time striding in imag. time [cschran]
!>         2023-07-23 Modified to work with NNP solute-solvent interactions [lduran]
!>         energy/force are now defined for every interaction type, including
!>         values not handled by the SELECT CASE
!> \author Lukasz Walewski
! **************************************************************************************************
   SUBROUTINE helium_bead_solute_e_f(pint_env, helium, helium_part_index, &
                                     helium_slice_index, helium_r_opt, energy, force)

      TYPE(pint_env_type), INTENT(IN)                    :: pint_env
      TYPE(helium_solvent_type), INTENT(IN)              :: helium
      INTEGER, INTENT(IN)                                :: helium_part_index, helium_slice_index
      REAL(KIND=dp), DIMENSION(3), INTENT(IN), OPTIONAL  :: helium_r_opt
      REAL(KIND=dp), INTENT(OUT)                         :: energy
      REAL(KIND=dp), DIMENSION(:, :), INTENT(INOUT), &
         OPTIONAL, POINTER                               :: force

      INTEGER                                            :: hbeads, hi, qi, stride
      REAL(KIND=dp), DIMENSION(3)                        :: helium_r
      REAL(KIND=dp), DIMENSION(:), POINTER               :: my_force

      hbeads = helium%beads
      ! helium bead index that is invariant wrt the rotations
      hi = MOD(helium_slice_index - 1 + hbeads + helium%relrot, hbeads) + 1
      ! solute bead index that belongs to hi helium index
      qi = ((hi - 1)*pint_env%p)/hbeads + 1

      ! coordinates of the helium bead
      IF (PRESENT(helium_r_opt)) THEN
         helium_r(:) = helium_r_opt(:)
      ELSE
         helium_r(:) = helium%pos(:, helium_part_index, helium_slice_index)
      END IF

      ! Start from "no interaction" so that energy (INTENT(OUT)) and force are always
      ! defined, also for interaction types not handled below.
      energy = 0.0_dp
      IF (PRESENT(force)) THEN
         force(:, :) = 0.0_dp
      END IF

      SELECT CASE (helium%solute_interaction)

      CASE (helium_solute_intpot_mwater)
         IF (PRESENT(force)) THEN
            my_force => force(qi, :)
            CALL helium_intpot_model_water(pint_env%x(qi, :), helium, helium_r, energy, my_force)
         ELSE
            CALL helium_intpot_model_water(pint_env%x(qi, :), helium, helium_r, energy)
         END IF

      CASE (helium_solute_intpot_nnp)
         IF (PRESENT(force)) THEN
            my_force => force(qi, :)
            CALL helium_intpot_nnp(pint_env%x(qi, :), helium, helium_r, energy, my_force)
         ELSE
            CALL helium_intpot_nnp(pint_env%x(qi, :), helium, helium_r, energy)
         END IF

      CASE (helium_solute_intpot_none)
         ! no helium-solute interaction: energy and force stay zero

      CASE DEFAULT
         ! unhandled interaction type: treated like "none" (energy and force stay zero)

      END SELECT

      ! Account for Imaginary time striding in forces:
      IF (PRESENT(force)) THEN
         IF (hbeads < pint_env%p) THEN
            stride = pint_env%p/hbeads
            force = force*REAL(stride, dp)
         END IF
      END IF

   END SUBROUTINE helium_bead_solute_e_f

! **************************************************************************************************
!> \brief Calculate total helium-solute interaction energy and forces.
!> \param   pint_env   path integral environment
!> \param   helium     helium environment; the total force is accumulated into
!>                     helium%force_inst, helium%rtmp_p_ndim_2d is used as scratch
!> \param   energy     calculated interaction energy
!> \author Lukasz Walewski
! **************************************************************************************************
   SUBROUTINE helium_solute_e_f(pint_env, helium, energy)

      TYPE(pint_env_type), INTENT(IN)                    :: pint_env
      TYPE(helium_solvent_type), INTENT(INOUT)           :: helium
      REAL(KIND=dp), INTENT(OUT)                         :: energy

      INTEGER                                            :: iat, ibd, nd, np
      REAL(KIND=dp)                                      :: e_bead
      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: ftot

      ftot => helium%force_inst
      np = pint_env%p
      nd = pint_env%ndim

      energy = 0.0_dp
      ftot(:, :) = 0.0_dp

      ! sum energy and solute forces over all beads of all helium atoms
      DO iat = 1, helium%atoms
         DO ibd = 1, helium%beads
            CALL helium_bead_solute_e_f(pint_env, helium, iat, ibd, &
                                        energy=e_bead, force=helium%rtmp_p_ndim_2d)
            energy = energy + e_bead
            ftot(1:np, 1:nd) = ftot(1:np, 1:nd) + helium%rtmp_p_ndim_2d(1:np, 1:nd)
         END DO
      END DO

   END SUBROUTINE helium_solute_e_f

! **************************************************************************************************
!> \brief Calculate total helium-solute interaction energy (no forces).
!> \param   pint_env   path integral environment
!> \param   helium     helium environment
!> \param   energy     interaction energy summed over all beads of all helium atoms
!> \author Lukasz Walewski
! **************************************************************************************************
   SUBROUTINE helium_solute_e(pint_env, helium, energy)

      TYPE(pint_env_type), INTENT(IN)                    :: pint_env
      TYPE(helium_solvent_type), INTENT(IN)              :: helium
      REAL(KIND=dp), INTENT(OUT)                         :: energy

      INTEGER                                            :: iat, ibd
      REAL(KIND=dp)                                      :: e_bead

      energy = 0.0_dp

      atom_loop: DO iat = 1, helium%atoms
         bead_loop: DO ibd = 1, helium%beads
            CALL helium_bead_solute_e_f(pint_env, helium, iat, ibd, energy=e_bead)
            energy = energy + e_bead
         END DO bead_loop
      END DO atom_loop

   END SUBROUTINE helium_solute_e

! **************************************************************************************************
!> \brief  Scan the helium-solute interaction energy within the periodic cell
!> \param pint_env    path integral environment
!> \param helium_env  helium environments; only helium_env(1) on the source rank is used
!> \date   2014-01-22
!> \par    History
!>         2016-07-14 Modified to work with independent helium_env [cschran]
!> \note   The interaction energy of a probe placed at each voxel centre is stored in
!>         rho_inst(1, :, :, :); voxels whose probe gets wrapped by helium_pbc are set
!>         to zero. The routine overwrites center(:) and pos(:, 1, 1) of that helium.
!> \author Lukasz Walewski
! **************************************************************************************************
   SUBROUTINE helium_intpot_scan(pint_env, helium_env)

      TYPE(pint_env_type), INTENT(IN)                    :: pint_env
      TYPE(helium_solvent_p_type), DIMENSION(:), POINTER :: helium_env

      CHARACTER(len=*), PARAMETER :: routineN = 'helium_intpot_scan'

      INTEGER                                            :: handle, i, j, l, nbin
      REAL(KIND=dp)                                      :: delr, e_probe
      REAL(KIND=dp), DIMENSION(3)                        :: corner, dr_pbc, dr_raw, probe

      CALL timeset(routineN, handle)

      ! only the source rank scans: the result is only used to output the intpot,
      ! and it is assumed to hold at least one helium_env
      IF (pint_env%logger%para_env%is_source()) THEN
         ASSOCIATE (he => helium_env(1)%helium)
            he%rho_inst(1, :, :, :) = 0.0_dp
            nbin = he%rho_nbin
            delr = he%rho_delr
            he%center(:) = 0.0_dp
            corner(1) = he%center(1) - he%rho_maxr/2.0_dp
            corner(2) = he%center(2) - he%rho_maxr/2.0_dp
            corner(3) = he%center(3) - he%rho_maxr/2.0_dp

            DO i = 1, nbin
               DO j = 1, nbin
                  DO l = 1, nbin
                     ! probe at the centre of voxel (i, j, l)
                     probe(1) = corner(1) + (i - 0.5_dp)*delr
                     probe(2) = corner(2) + (j - 0.5_dp)*delr
                     probe(3) = corner(3) + (l - 0.5_dp)*delr

                     ! interaction energy of a helium bead sitting at the probe
                     he%pos(:, 1, 1) = probe(:)
                     CALL helium_bead_solute_e_f(pint_env, he, 1, 1, energy=e_probe)

                     ! a probe outside the unit cell is moved by helium_pbc
                     dr_raw(:) = probe(:) - he%center
                     dr_pbc(:) = dr_raw(:)
                     CALL helium_pbc(he, dr_pbc)

                     IF (ANY(ABS(dr_raw - dr_pbc) > 10.0_dp*EPSILON(0.0_dp))) THEN
                        he%rho_inst(1, i, j, l) = 0.0_dp
                     ELSE
                        he%rho_inst(1, i, j, l) = e_probe
                     END IF
                  END DO
               END DO
            END DO
         END ASSOCIATE
      END IF

      CALL timestop(handle)
   END SUBROUTINE helium_intpot_scan

! **************************************************************************************************
!> \brief Calculate model helium-solute interaction energy and forces
!>        between one helium bead and the corresponding solute time
!>        slice assuming a water solute.
!> \param solute_x  solute positions ARR(3*NATOMS); atom i occupies
!>                  entries 3*(i-1)+1 .. 3*(i-1)+3
!> \param helium    provides solute_element(:) and is used for the helium_pbc call
!> \param helium_x  helium bead position ARR(3)
!> \param energy    calculated interaction energy
!> \param force     (optional) forces on the solute atoms, same layout as solute_x
!> \note  Sum of Lennard-Jones terms V(d) = 4*eps*((sig/d)**12 - (sig/d)**6) between the
!>        helium bead and every solute atom labelled "H " or "O "; other solute atoms
!>        do not interact. Distances use the minimum image from helium_pbc.
!> \author Felix Uhl
! **************************************************************************************************
   SUBROUTINE helium_intpot_model_water(solute_x, helium, helium_x, energy, force)

      REAL(KIND=dp), DIMENSION(:), INTENT(IN)            :: solute_x
      TYPE(helium_solvent_type), INTENT(IN)              :: helium
      REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: helium_x
      REAL(KIND=dp), INTENT(OUT)                         :: energy
      REAL(KIND=dp), DIMENSION(:), INTENT(INOUT), &
         OPTIONAL, POINTER                               :: force

      energy = 0.0_dp
      IF (PRESENT(force)) THEN
         force(:) = 0.0_dp
      END IF

      ! He-H: sigma = 1.4 Angstrom, epsilon = 19 K (values below in atomic units)
      CALL add_lennard_jones("H ", 2.69_dp, 60.61e-6_dp)
      ! He-O: sigma = 2.6 Angstrom, epsilon = 33 K (values below in atomic units)
      CALL add_lennard_jones("O ", 5.01_dp, 104.5e-6_dp)

   CONTAINS

! **************************************************************************************************
!> \brief Add the He-X Lennard-Jones energy of all solute atoms labelled element to
!>        energy and, if force is present, the corresponding forces to force
!> \param element  label compared against helium%solute_element(:)
!> \param sig      Lennard-Jones sigma
!> \param eps      Lennard-Jones epsilon
! **************************************************************************************************
      SUBROUTINE add_lennard_jones(element, sig, eps)

         CHARACTER(LEN=*), INTENT(IN)                    :: element
         REAL(KIND=dp), INTENT(IN)                       :: sig, eps

         INTEGER                                         :: i, ig
         REAL(KIND=dp)                                   :: d, d2, dd, s1, s2
         REAL(KIND=dp), DIMENSION(3)                     :: dr

         s1 = 0.0_dp
         DO i = 1, SIZE(helium%solute_element)
            IF (helium%solute_element(i) /= element) CYCLE
            ig = 3*(i - 1)
            dr(:) = solute_x(ig + 1:ig + 3) - helium_x(:)
            CALL helium_pbc(helium, dr)
            d2 = dr(1)*dr(1) + dr(2)*dr(2) + dr(3)*dr(3)
            d = SQRT(d2)
            dd = (sig/d)**6
            s1 = s1 + 4.0_dp*eps*dd*(dd - 1.0_dp)
            IF (PRESENT(force)) THEN
               ! force on the solute atom: -dV/d(r_solute) = 24*eps*dd*(2*dd - 1)/d**2 * dr
               s2 = 24.0_dp*eps*dd*(2.0_dp*dd - 1.0_dp)/d2
               force(ig + 1:ig + 3) = force(ig + 1:ig + 3) + s2*dr(:)
            END IF
         END DO
         energy = energy + s1

      END SUBROUTINE add_lennard_jones

   END SUBROUTINE helium_intpot_model_water

! **************************************************************************************************
!> \brief  Calculate helium-solute interaction energy and forces between one
!>         helium bead and the corresponding solute time slice using NNP.
!> \param  solute_x  solute positions ARR(3*NATOMS); atom j occupies
!>                   entries 3*(j-1)+1 .. 3*(j-1)+3
!> \param  helium    helium environment; helium%nnp is used as work space and receives
!>                   coordinates, atomic/committee energies and forces
!> \param  helium_x  helium bead position ARR(3)
!> \param  energy    committee-averaged NNP energy, or 0.3 + 1/rsqr if any atom lies
!>                   inside helium%nnp_sr_cut of the helium bead (hard-core check)
!> \param  force     (optional) calculated force on the solute atoms, same layout as solute_x
!> \date   2023-02-22
!> \author Laura Duran
! **************************************************************************************************
   SUBROUTINE helium_intpot_nnp(solute_x, helium, helium_x, energy, force)

      REAL(KIND=dp), DIMENSION(:), INTENT(IN)            :: solute_x
      TYPE(helium_solvent_type), INTENT(IN)              :: helium
      REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: helium_x
      REAL(KIND=dp), INTENT(OUT)                         :: energy
      REAL(KIND=dp), DIMENSION(:), INTENT(INOUT), &
         OPTIONAL, POINTER                               :: force

      INTEGER                                            :: i, i_com, ig, ind, ind_he, j, k, m
      REAL(KIND=dp)                                      :: rsqr, rvect(3)
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: denergydsym
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :)     :: dsymdxyz
      TYPE(cp_logger_type), POINTER                      :: logger
      TYPE(nnp_type), POINTER                            :: nnp
      TYPE(section_vals_type), POINTER                   :: print_section

      NULLIFY (logger)
      logger => cp_get_default_logger()

      IF (PRESENT(force)) THEN
         helium%nnp%myforce(:, :, :) = 0.0_dp
      END IF

      ! fill coord array in NNP element order: for each element, the helium bead first
      ! (if the element is He), then every solute atom carrying that element label
      ! NOTE(review): ind_he stays undefined if helium%nnp%ele(:) contains no 'He'
      ig = 1
      DO i = 1, helium%nnp%n_ele
         IF (helium%nnp%ele(i) == 'He') THEN
            ind_he = ig
            DO m = 1, 3
               helium%nnp%coord(m, ig) = helium_x(m)
            END DO
            ig = ig + 1
         END IF
         DO j = 1, helium%solute_atoms
            IF (helium%nnp%ele(i) == helium%solute_element(j)) THEN
               DO m = 1, 3
                  helium%nnp%coord(m, ig) = solute_x(3*(j - 1) + m)
               END DO
               ig = ig + 1
            END IF
         END DO
      END DO

      ! check for hard core condition: return a repulsive penalty instead of the NNP
      ! energy if any atom is closer than the element-specific short-range cutoff
      IF (ASSOCIATED(helium%nnp_sr_cut)) THEN
         DO i = 1, helium%nnp%num_atoms
            IF (i == ind_he) CYCLE
            rvect(:) = helium%nnp%coord(:, i) - helium%nnp%coord(:, ind_he)
            CALL helium_pbc(helium, rvect)
            rsqr = rvect(1)*rvect(1) + rvect(2)*rvect(2) + rvect(3)*rvect(3)
            IF (rsqr < helium%nnp_sr_cut(helium%nnp%ele_ind(i))) THEN
               energy = 0.3_dp + 1.0_dp/rsqr
               IF (PRESENT(force)) THEN
                  force = 0.0_dp
               END IF
               RETURN
            END IF
         END DO
      END IF

      ! reset flag if there's an extrapolation to report:
      helium%nnp%output_expol = .FALSE.

      ! local pointer: work around wrong INTENT of nnp_calc_acsf (loop invariant)
      nnp => helium%nnp

      ! calc atomic contribution to energy and force
!NOTE corresponds to nnp_force line with parallelization:
!DO i = istart, istart + mecalc - 1
      DO i = 1, helium%nnp%num_atoms

         !determine index of atom type
         ind = helium%nnp%ele_ind(i)

         !reset input nodes and grads of ele(ind):
         helium%nnp%arc(ind)%layer(1)%node(:) = 0.0_dp
         IF (PRESENT(force)) THEN
            helium%nnp%arc(ind)%layer(1)%node_grad(:) = 0.0_dp
            ALLOCATE (dsymdxyz(3, helium%nnp%arc(ind)%n_nodes(1), helium%nnp%num_atoms))
            ALLOCATE (denergydsym(helium%nnp%arc(ind)%n_nodes(1)))
            dsymdxyz(:, :, :) = 0.0_dp
            CALL nnp_calc_acsf(nnp, i, dsymdxyz)
         ELSE
            CALL nnp_calc_acsf(nnp, i)
         END IF

         ! input nodes filled, perform prediction:
         DO i_com = 1, helium%nnp%n_committee !loop over committee members
            ! Predict energy
            CALL nnp_predict(helium%nnp%arc(ind), helium%nnp, i_com)
            helium%nnp%atomic_energy(i, i_com) = helium%nnp%arc(ind)%layer(helium%nnp%n_layer)%node(1) ! + helium%nnp%atom_energies(ind)

            !Gradients: chain rule dE/dsym * dsym/dxyz
            IF (PRESENT(force)) THEN

               denergydsym(:) = 0.0_dp

               CALL nnp_gradients(helium%nnp%arc(ind), helium%nnp, i_com, denergydsym)
               DO j = 1, helium%nnp%arc(ind)%n_nodes(1)
                  DO k = 1, helium%nnp%num_atoms
                     DO m = 1, 3
                        helium%nnp%myforce(m, k, i_com) = helium%nnp%myforce(m, k, i_com) &
                                                          - denergydsym(j)*dsymdxyz(m, j, k)
                     END DO
                  END DO
               END DO

            END IF
         END DO ! end loop over committee members

         !deallocate memory
         IF (PRESENT(force)) THEN
            DEALLOCATE (denergydsym)
            DEALLOCATE (dsymdxyz)
         END IF

      END DO ! end loop over num_atoms

      ! calculate energy: committee average of the summed atomic energies
      helium%nnp%committee_energy(:) = SUM(helium%nnp%atomic_energy, 1)
      energy = SUM(helium%nnp%committee_energy)/REAL(helium%nnp%n_committee, dp)
      helium%nnp%nnp_potential_energy = energy

      IF (PRESENT(force)) THEN
         ! bring myforce to force array
         DO j = 1, helium%nnp%num_atoms
            DO k = 1, 3
               helium%nnp%committee_forces(k, j, :) = helium%nnp%myforce(k, j, :)
            END DO
         END DO
         helium%nnp%nnp_forces(:, :) = SUM(helium%nnp%committee_forces, DIM=3)/REAL(helium%nnp%n_committee, dp)
         ! project out helium force entry; remaining atoms are mapped back to the
         ! solute ordering through helium%nnp%sort
         ig = 1
         DO j = 1, helium%nnp%num_atoms
            IF (j == ind_he) CYCLE
            DO k = 1, 3
               force(3*(helium%nnp%sort(ig) - 1) + k) = helium%nnp%nnp_forces(k, j)
            END DO
            ig = ig + 1
         END DO
      END IF

      ! print properties if requested
      print_section => section_vals_get_subs_vals(helium%nnp%nnp_input, "PRINT")
      CALL helium_nnp_print(helium%nnp, print_section, ind_he)

   END SUBROUTINE helium_intpot_nnp

! **************************************************************************************************
!> \brief Helium-helium pair interaction potential.
!> \param r  He-He distance; the factor angstrom*r below suggests r is in bohr
!>           (NOTE(review): confirm against callers)
!> \return   pair potential; the 10.8/kelvin prefactor presumably converts the
!>           Kelvin energy scale to Hartree
!> \note     V = eps*(A*exp(-alpha*x) - F(x)*(C6/x**6 + C8/x**8 + C10/x**10)),
!>           x = r/rm, with F(x) = exp(-(D/x - 1)**2) for x < D and 1 otherwise.
!>           The constants appear to be those of the Aziz HFDHE2 potential — confirm.
! **************************************************************************************************
   ELEMENTAL FUNCTION helium_vij(r) RESULT(vij)

      REAL(kind=dp), INTENT(IN)                          :: r
      REAL(kind=dp)                                      :: vij

      REAL(kind=dp), PARAMETER :: a_rep = 544850.4_dp, alpha = 13.353384_dp, &
                                  c6 = 1.3732412_dp, c8 = 0.4253785_dp, c10 = 0.1781_dp, &
                                  d_damp = 1.241314_dp, eps_k = 10.8_dp, rm = 2.9673_dp

      REAL(kind=dp)                                      :: damp, xi, xi_inv2, y

      xi = angstrom*r/rm

      ! damping of the dispersion tail at short range
      damp = 1.0_dp
      IF (xi < d_damp) THEN
         y = d_damp/xi - 1.0_dp
         damp = EXP(-y*y)
      END IF

      xi_inv2 = 1.0_dp/(xi*xi)
      vij = eps_k/kelvin*(a_rep*EXP(-alpha*xi) - damp* &
                          ((c10*xi_inv2 + c8)*xi_inv2 + c6)*xi_inv2*xi_inv2*xi_inv2)
   END FUNCTION helium_vij

#if 0

   ! this block is currently turned off

! **************************************************************************************************
!> \brief Helium-helium pair interaction potential's derivative.
!> \param r He-He separation, reduced exactly as in helium_vij (x = angstrom*r/2.9673)
!> \return (1/r) * dV/dr for the potential V evaluated by helium_vij
!> \note The factor angstrom*(10.8/2.9673)/kelvin is epsilon times dx/dr for the reduced
!>       separation x = angstrom*r/2.9673; x must therefore not be recomputed without the
!>       angstrom conversion, otherwise the result no longer matches helium_vij.
! **************************************************************************************************
   ELEMENTAL FUNCTION helium_d_vij(r) RESULT(dvij)

      REAL(kind=dp), INTENT(IN)                          :: r
      REAL(kind=dp)                                      :: dvij

      REAL(kind=dp)                                      :: f, fp, x, x2, y

      ! reduced separation, identical to the one used in helium_vij
      x = angstrom*r/2.9673_dp
      x2 = 1.0_dp/(x*x)
      IF (x < 1.241314_dp) THEN
         ! damping F = exp(-y^2) with y = D/x - 1; fp = F'(x) * (C6/x^6 + C8/x^8 + C10/x^10)
         y = 1.241314_dp/x - 1.0_dp
         f = EXP(-y*y)
         fp = 2.0_dp*1.241314_dp*f*y* &
              ((0.1781_dp*x2 + 0.4253785_dp)*x2 + 1.3732412_dp)*x2*x2*x2*x2
      ELSE
         f = 1.0_dp
         fp = 0.0_dp
      END IF

      ! dV/dx (repulsion, -F'*P, -F*P') times dx/dr, converted by kelvin and divided by r
      dvij = angstrom*(10.8_dp/2.9673_dp)*( &
             (-13.353384_dp*544850.4_dp)*EXP(-13.353384_dp*x) - fp + &
             f*(((10.0_dp*0.1781_dp)*x2 + (8.0_dp*0.4253785_dp))*x2 + (6.0_dp*1.3732412_dp))* &
             x2*x2*x2/x)/(r*kelvin)
   END FUNCTION helium_d_vij

! **************************************************************************************************
!> \brief Action contribution of a single helium atom on the link starting at bead n.
!> \param helium helium solvent environment (pos, permutation, tau, hb2m, uoffdiag are read)
!> \param n bead index at which the link starts; when n >= helium%beads the link partner is
!>          bead 1 of atom helium%permutation(i) instead of bead n+1 of atom i
!> \param i index of the atom whose action is evaluated
!> \return kin/(2*tau*hb2m) + pair/2, where kin is the squared helium_pbc-wrapped displacement
!>         of atom i along the link and pair is the sum of helium_eval_expansion over all
!>         pairs (i, j) with j /= i on that link
! **************************************************************************************************
   FUNCTION helium_atom_action(helium, n, i) RESULT(res)

      TYPE(helium_solvent_type), INTENT(INOUT)           :: helium
      INTEGER, INTENT(IN)                                :: n, i
      REAL(KIND=dp)                                      :: res

      INTEGER                                            :: ia_next, ja, ja_next, nb_next
      LOGICAL                                            :: closes_ring
      REAL(KIND=dp)                                      :: kin, pair
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: work
      REAL(KIND=dp), DIMENSION(3)                        :: d_cur, d_next

      ALLOCATE (work(SIZE(helium%uoffdiag, 1) + 1))

      ! past the last bead the link closes onto bead 1 of the permuted atom
      closes_ring = (n >= helium%beads)
      IF (closes_ring) THEN
         nb_next = 1
         ia_next = helium%permutation(i)
      ELSE
         nb_next = n + 1
         ia_next = i
      END IF

      ! free-particle term of atom i along the link
      d_cur(:) = helium%pos(:, i, n) - helium%pos(:, ia_next, nb_next)
      CALL helium_pbc(helium, d_cur)
      kin = d_cur(1)*d_cur(1) + d_cur(2)*d_cur(2) + d_cur(3)*d_cur(3)

      ! pair terms with every other atom, accumulated in increasing atom order
      pair = 0.0_dp
      DO ja = 1, helium%atoms
         IF (ja == i) CYCLE
         IF (closes_ring) THEN
            ja_next = helium%permutation(ja)
         ELSE
            ja_next = ja
         END IF
         d_cur(:) = helium%pos(:, i, n) - helium%pos(:, ja, n)
         d_next(:) = helium%pos(:, ia_next, nb_next) - helium%pos(:, ja_next, nb_next)
         pair = pair + helium_eval_expansion(helium, d_cur, d_next, work)
      END DO

      kin = kin/(2.0_dp*helium%tau*helium%hb2m)
      ! each pair term is shared between its two atoms
      pair = pair*0.5_dp
      res = pair + kin

      DEALLOCATE (work)

   END FUNCTION helium_atom_action

! **************************************************************************************************
!> \brief Action of all helium atoms on the link starting at bead n.
!> \param helium helium solvent environment (pos, permutation, tau, hb2m, uoffdiag are read)
!> \param n bead index at which the link starts; when n >= helium%beads each atom is linked
!>          to bead 1 of its helium%permutation partner instead of to bead n+1
!> \return kin/(2*tau*hb2m) + pair, where kin sums the squared helium_pbc-wrapped link
!>         displacements of all atoms and pair sums helium_eval_expansion over all pairs j < i
! **************************************************************************************************
   FUNCTION helium_link_action(helium, n) RESULT(res)

      TYPE(helium_solvent_type), INTENT(INOUT)           :: helium
      INTEGER, INTENT(IN)                                :: n
      REAL(KIND=dp)                                      :: res

      INTEGER                                            :: ia, ia_next, ja, ja_next, nb_next
      LOGICAL                                            :: closes_ring
      REAL(KIND=dp)                                      :: kin, pair
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: work
      REAL(KIND=dp), DIMENSION(3)                        :: d_cur, d_next

      ALLOCATE (work(SIZE(helium%uoffdiag, 1) + 1))

      ! past the last bead the link closes onto bead 1 through the permutation
      closes_ring = (n >= helium%beads)
      IF (closes_ring) THEN
         nb_next = 1
      ELSE
         nb_next = n + 1
      END IF

      kin = 0.0_dp
      pair = 0.0_dp
      DO ia = 1, helium%atoms
         ia_next = ia
         IF (closes_ring) ia_next = helium%permutation(ia)

         ! free-particle term of atom ia
         d_cur(:) = helium%pos(:, ia, n) - helium%pos(:, ia_next, nb_next)
         CALL helium_pbc(helium, d_cur)
         kin = kin + d_cur(1)*d_cur(1) + d_cur(2)*d_cur(2) + d_cur(3)*d_cur(3)

         ! pair terms, each unordered pair counted once (ja < ia)
         DO ja = 1, ia - 1
            ja_next = ja
            IF (closes_ring) ja_next = helium%permutation(ja)
            d_cur(:) = helium%pos(:, ia, n) - helium%pos(:, ja, n)
            d_next(:) = helium%pos(:, ia_next, nb_next) - helium%pos(:, ja_next, nb_next)
            pair = pair + helium_eval_expansion(helium, d_cur, d_next, work)
         END DO
      END DO

      kin = kin/(2.0_dp*helium%tau*helium%hb2m)
      res = pair + kin

      DEALLOCATE (work)

   END FUNCTION helium_link_action

! **************************************************************************************************
!> \brief Total action of the helium system, summed over all bead links.
!> \param helium helium solvent environment
!> \return sum of helium_link_action(helium, n) for n = 1, ..., helium%beads
! **************************************************************************************************
   FUNCTION helium_total_action(helium) RESULT(res)

      TYPE(helium_solvent_type), INTENT(INOUT)           :: helium
      REAL(KIND=dp)                                      :: res

      INTEGER                                            :: ib

      res = 0.0_dp
      DO ib = 1, helium%beads
         res = res + helium_link_action(helium, ib)
      END DO

   END FUNCTION helium_total_action

! **************************************************************************************************
!> \brief Accumulated displacement when walking delta_bead bead steps along a helium path.
!> \param helium helium solvent environment (pos, beads, permutation, iperm are read)
!> \param part atom index at the reference bead
!> \param ref_bead bead index where the walk starts
!> \param delta_bead signed number of steps: positive walks forward (wrapping from the last bead
!>        to bead 1 via helium%permutation), negative walks backward (wrapping from bead 1 to
!>        the last bead via helium%iperm)
!> \param d on exit, the sum over all steps of pos(current) - pos(next), each passed through
!>        helium_pbc; zero when delta_bead == 0
! **************************************************************************************************
   SUBROUTINE helium_delta_pos(helium, part, ref_bead, delta_bead, d)

      TYPE(helium_solvent_type), INTENT(INOUT)           :: helium
      INTEGER, INTENT(IN)                                :: part, ref_bead, delta_bead
      REAL(KIND=dp), DIMENSION(3), INTENT(OUT)           :: d

      INTEGER                                            :: atom, bead, nbeads, next_atom, &
                                                            next_bead, remaining, step
      REAL(KIND=dp), DIMENSION(3)                        :: dr

      d(:) = 0.0_dp
      IF (delta_bead == 0) RETURN

      nbeads = helium%beads
      step = SIGN(1, delta_bead)
      atom = part
      bead = ref_bead
      remaining = delta_bead

      DO WHILE (remaining /= 0)
         next_bead = bead + step
         next_atom = atom
         ! wrap around the ring, switching atoms through the (inverse) permutation
         IF (step > 0 .AND. next_bead > nbeads) THEN
            next_bead = next_bead - nbeads
            next_atom = helium%permutation(atom)
         ELSE IF (step < 0 .AND. next_bead < 1) THEN
            next_bead = next_bead + nbeads
            next_atom = helium%iperm(atom)
         END IF
         dr(:) = helium%pos(:, atom, bead) - helium%pos(:, next_atom, next_bead)
         CALL helium_pbc(helium, dr)
         d(:) = d(:) + dr(:)
         atom = next_atom
         bead = next_bead
         remaining = remaining - step
      END DO

   END SUBROUTINE helium_delta_pos

#endif

END MODULE helium_interactions
